{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0b766fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导包\n",
    "import torch\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "import torch.nn as nn \n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f32c4d1b",
   "metadata": {},
   "source": [
    "## data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "507db569",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1d2c56523614f5195807d150fa17419",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/9912422 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a22149088454e72bf9b62f2778dc164",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/28881 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b99ea20af1ee4ad096cf1a4a6a6d60d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1648877 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cea73404267a4e65a32fe0ae96f8747b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4542 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_data = datasets.MNIST(root=\"data/mnist\",train=True,transform=transforms.ToTensor(),download=True)\n",
    "test_data = datasets.MNIST(root=\"data/mnist\",train=False,transform=transforms.ToTensor(),download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d01ab34",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_data,batch_size=batch_size,shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3862ea5d",
   "metadata": {},
   "source": [
    "## net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4f86e0eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义 MLP 网络  继承nn.Module\n",
    "class MLP(nn.Module):\n",
    "    \n",
    "    # 初始化方法\n",
    "    # input_size输入数据的维度    \n",
    "    # hidden_size 隐藏层的大小\n",
    "    # num_classes 输出分类的数量\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        # 调用父类的初始化方法\n",
    "        super(MLP, self).__init__()\n",
    "        # 定义第1个全连接层  \n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        # 定义激活函数\n",
    "        self.relu = nn.ReLU()\n",
    "        # 定义第2个全连接层\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        # 定义第3个全连接层\n",
    "        self.fc3 = nn.Linear(hidden_size, num_classes)\n",
    "        \n",
    "    # 定义forward函数\n",
    "    # x 输入的数据\n",
    "    def forward(self, x):\n",
    "        # 第一层运算\n",
    "        out = self.fc1(x)\n",
    "        # 将上一步结果送给激活函数\n",
    "        out = self.relu(out)\n",
    "        # 将上一步结果送给fc2\n",
    "        out = self.fc2(out)\n",
    "        # 同样将结果送给激活函数\n",
    "        out = self.relu(out)\n",
    "        # 将上一步结果传递给fc3\n",
    "        out = self.fc3(out)\n",
    "        # 返回结果\n",
    "        return out\n",
    "    \n",
    "# 定义参数    \n",
    "input_size = 28 * 28  # 输入大小\n",
    "hidden_size = 512  # 隐藏层大小\n",
    "num_classes = 10  # 输出大小（类别数） \n",
    "\n",
    "# 初始化MLP    \n",
    "model = MLP(input_size, hidden_size, num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d3a1ae6",
   "metadata": {},
   "source": [
    "## loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6f7329fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cd81a0b",
   "metadata": {},
   "source": [
    "## optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f1d15bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 0.001 # 学习率\n",
    "optimizer = optim.Adam(model.parameters(),lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd0df25",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec0d291a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10], Step [100/600], Loss: 0.3697\n",
      "Epoch [1/10], Step [200/600], Loss: 0.1534\n",
      "Epoch [1/10], Step [300/600], Loss: 0.1699\n",
      "Epoch [1/10], Step [400/600], Loss: 0.0657\n",
      "Epoch [1/10], Step [500/600], Loss: 0.1864\n",
      "Epoch [1/10], Step [600/600], Loss: 0.0878\n",
      "Epoch [2/10], Step [100/600], Loss: 0.0853\n",
      "Epoch [2/10], Step [200/600], Loss: 0.0340\n",
      "Epoch [2/10], Step [300/600], Loss: 0.1702\n",
      "Epoch [2/10], Step [400/600], Loss: 0.0413\n",
      "Epoch [2/10], Step [500/600], Loss: 0.0730\n",
      "Epoch [2/10], Step [600/600], Loss: 0.0986\n",
      "Epoch [3/10], Step [100/600], Loss: 0.0139\n",
      "Epoch [3/10], Step [200/600], Loss: 0.0562\n",
      "Epoch [3/10], Step [300/600], Loss: 0.0235\n",
      "Epoch [3/10], Step [400/600], Loss: 0.0731\n",
      "Epoch [3/10], Step [500/600], Loss: 0.0398\n",
      "Epoch [3/10], Step [600/600], Loss: 0.1915\n",
      "Epoch [4/10], Step [100/600], Loss: 0.0118\n",
      "Epoch [4/10], Step [200/600], Loss: 0.0911\n",
      "Epoch [4/10], Step [300/600], Loss: 0.0256\n",
      "Epoch [4/10], Step [400/600], Loss: 0.0879\n",
      "Epoch [4/10], Step [500/600], Loss: 0.0045\n",
      "Epoch [4/10], Step [600/600], Loss: 0.0191\n",
      "Epoch [5/10], Step [100/600], Loss: 0.0073\n",
      "Epoch [5/10], Step [200/600], Loss: 0.0125\n",
      "Epoch [5/10], Step [300/600], Loss: 0.0421\n",
      "Epoch [5/10], Step [400/600], Loss: 0.0424\n",
      "Epoch [5/10], Step [500/600], Loss: 0.0099\n",
      "Epoch [5/10], Step [600/600], Loss: 0.0043\n",
      "Epoch [6/10], Step [100/600], Loss: 0.0086\n",
      "Epoch [6/10], Step [200/600], Loss: 0.0070\n",
      "Epoch [6/10], Step [300/600], Loss: 0.0092\n",
      "Epoch [6/10], Step [400/600], Loss: 0.0152\n",
      "Epoch [6/10], Step [500/600], Loss: 0.0071\n",
      "Epoch [6/10], Step [600/600], Loss: 0.0038\n",
      "Epoch [7/10], Step [100/600], Loss: 0.0414\n",
      "Epoch [7/10], Step [200/600], Loss: 0.0159\n",
      "Epoch [7/10], Step [300/600], Loss: 0.0332\n",
      "Epoch [7/10], Step [400/600], Loss: 0.0054\n",
      "Epoch [7/10], Step [500/600], Loss: 0.0067\n",
      "Epoch [7/10], Step [600/600], Loss: 0.0072\n",
      "Epoch [8/10], Step [100/600], Loss: 0.0030\n",
      "Epoch [8/10], Step [200/600], Loss: 0.0046\n",
      "Epoch [8/10], Step [300/600], Loss: 0.0492\n",
      "Epoch [8/10], Step [400/600], Loss: 0.0126\n",
      "Epoch [8/10], Step [500/600], Loss: 0.0592\n",
      "Epoch [8/10], Step [600/600], Loss: 0.0073\n",
      "Epoch [9/10], Step [100/600], Loss: 0.0520\n",
      "Epoch [9/10], Step [200/600], Loss: 0.0031\n",
      "Epoch [9/10], Step [300/600], Loss: 0.0036\n",
      "Epoch [9/10], Step [400/600], Loss: 0.0077\n",
      "Epoch [9/10], Step [500/600], Loss: 0.0097\n",
      "Epoch [9/10], Step [600/600], Loss: 0.0029\n",
      "Epoch [10/10], Step [100/600], Loss: 0.0002\n",
      "Epoch [10/10], Step [200/600], Loss: 0.0021\n",
      "Epoch [10/10], Step [300/600], Loss: 0.0235\n",
      "Epoch [10/10], Step [400/600], Loss: 0.0004\n",
      "Epoch [10/10], Step [500/600], Loss: 0.0343\n",
      "Epoch [10/10], Step [600/600], Loss: 0.0249\n"
     ]
    }
   ],
   "source": [
    "# 训练网络\n",
    "\n",
    "num_epochs = 10 # 训练轮数\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # 将iamges转成向量\n",
    "        images = images.reshape(-1, 28 * 28)\n",
    "        # 将数据送到网络中\n",
    "        outputs = model(images)\n",
    "        # 计算损失\n",
    "        loss = criterion(outputs, labels)\n",
    "        \n",
    "        # 首先将梯度清零\n",
    "        optimizer.zero_grad()\n",
    "        # 反向传播\n",
    "        loss.backward()\n",
    "        # 更新参数\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (i + 1) % 100 == 0:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3ffc84",
   "metadata": {},
   "source": [
    "## test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "565c0019",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 10000 test images: 98.01 %\n"
     ]
    }
   ],
   "source": [
    "# 测试网络\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    # 从 test_loader中循环读取测试数据\n",
    "    for images, labels in test_loader:\n",
    "        # 将images转成向量\n",
    "        images = images.reshape(-1, 28 * 28)\n",
    "        # 将数据送给网络\n",
    "        outputs = model(images)\n",
    "        # 取出最大值对应的索引  即预测值\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        # 累加label数\n",
    "        total += labels.size(0)\n",
    "        # 预测值与labels值比对 获取预测正确的数量\n",
    "        correct += (predicted == labels).sum().item()\n",
    "    # 打印最终的准确率\n",
    "    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec351cb8",
   "metadata": {},
   "source": [
    "## save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b1f14822",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model,\"mnist_mlp_model.pkl\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
